Dataset Description: Mutagenicity is closely related to carcinogenicity. Today the most widely used assay for testing the mutagenicity of compounds is the Ames test, developed by Bruce Ames: a short-term bacterial reverse mutation assay that detects compounds capable of inducing genetic damage and frameshift mutations. The dataset is aggregated from four papers.
Task Description: Binary classification. Given a drug SMILES string, predict whether it is mutagenic (1) or not mutagenic (0).
from forgebox.imports import *  # wildcard import; provides pd, torch, nn, Dataset, DataLoader used below
from gc_utils.config import ObjDict
from plotly import express as px
from transformers import AutoModel, AutoTokenizer
import pytorch_lightning as pl
from gc_utils import DBs
config = ObjDict(
    pretrained="seyonec/ChemBERTA_PubChem1M_shard00_155k",
    bs=32,
    versions={
        pl.__name__: pl.__version__,
        torch.__name__: torch.__version__,
    },
)
config
from tdc.single_pred import Tox
data = Tox(name = 'AMES')
split = data.get_split()
def get_split(split):
    """
    return the train, valid, test dataframes
    """
    return split["train"], split["valid"], split["test"]
train, valid, test = get_split(split)
for df in [train, valid, test]:
    df["sm_len"] = df.Drug.apply(len)  # raw SMILES string length in characters
train
train.vc("Y")
valid.vc("Y")
test.vc("Y")
Let's look at the raw SMILES string lengths; the tokenized sequences will be much shorter, as we verify below once the tokenizer is loaded.
px.histogram(train, x="sm_len")
px.histogram(valid, x="sm_len")
px.histogram(test, x="sm_len")
config.y_mean, config.y_std = train.Y.mean(), train.Y.std()  # label statistics, used only by the dataset's normalization helpers below
config.y_mean, config.y_std
We download this pretrained model, which is based on the paper ChemBERTa: Large-Scale Self-Supervised Pretraining for Molecular Property Prediction:
config.pretrained
tokenizer = AutoTokenizer.from_pretrained(config.pretrained, use_fast=True)
model = AutoModel.from_pretrained(config.pretrained)
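As a quick check on the earlier claim that tokenized sequences are much shorter than the raw strings (a sketch; exact counts depend on the tokenizer's BPE vocabulary):
smiles = train.Drug.iloc[0]
token_ids = tokenizer(smiles)["input_ids"]
len(smiles), len(token_ids)  # character count vs. token count (including special tokens)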
class ToxDataSet(Dataset):
    def __init__(self, df, tokenizer):
        self.df = df
        self.index_col = df.index
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.df)

    @staticmethod
    def normalize_y(y):
        return (y - config.y_mean) / config.y_std

    @staticmethod
    def denormalize_y(y):
        return (y * config.y_std) + config.y_mean

    def __getitem__(self, idx):
        row = self.df.loc[self.index_col[idx]]
        smiles = row["Drug"]
        y = row["Y"]
        # return x, y tuple
        return smiles, y

    def collate_fn(self, rows):
        x, y = zip(*rows)
        # note: only input_ids are returned; the attention_mask is dropped,
        # so the model also attends to padding tokens
        return self.tokenizer(list(x), return_tensors="pt", padding="longest")["input_ids"], \
            torch.FloatTensor(list(y))[:, None]

    def __repr__(self):
        return f"""ToxDataSet: {len(self.df)} rows
        with tokenizer: {self.tokenizer.name_or_path}"""
train_ds = ToxDataSet(pd.concat([train, valid]).reset_index(drop=True), tokenizer)  # train on train + valid
valid_ds = ToxDataSet(test, tokenizer)  # validate on the held-out test split
train_ds, valid_ds
This is what x and y look like, intuitively:
train_ds[5]
pl.__version__
class ToxLDM(pl.LightningDataModule):
    # uses the module-level train_ds / valid_ds defined above
    def __init__(self):
        super().__init__()

    def train_dataloader(self):
        return DataLoader(
            dataset=train_ds,
            batch_size=config.bs,
            collate_fn=train_ds.collate_fn,
            shuffle=True)

    def val_dataloader(self):
        return DataLoader(
            dataset=valid_ds, batch_size=config.bs * 4,
            collate_fn=valid_ds.collate_fn, shuffle=False)
tox_ldm = ToxLDM()
x, y = next(iter(tox_ldm.train_dataloader()))
x.shape,y.shape
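To sanity-check the collation, we can decode the first padded row back to text (a quick sketch):
tokenizer.decode(x[0])  # the original SMILES wrapped in special tokens, plus any padding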
class ToxLightning(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.base = model
        self.top = nn.Sequential(
            nn.BatchNorm1d(model.config.hidden_size),
            nn.Linear(model.config.hidden_size, 1)
        )
        self.sigmoid = nn.Sigmoid()
        self.crit = nn.BCEWithLogitsLoss()
        self.acc = pl.metrics.Accuracy()
        self.prec = pl.metrics.Precision(num_classes=1)
        self.rec = pl.metrics.Recall(num_classes=1)

    def configure_optimizers(self):
        """
        2 optimizers: 1 for the base model, 1 for the top layer
        """
        base_opt = torch.optim.Adam(self.base.parameters(), lr=1e-6)
        top_opt = torch.optim.Adam(self.top.parameters(), lr=1e-3)
        return base_opt, top_opt

    def forward(self, x):
        # use the pooled [CLS] representation as the molecule embedding
        cls_vec = self.base(x).pooler_output
        return self.top(cls_vec)

    def training_step(self, batch, batch_idx, optimizer_idx):
        x, y = batch
        y_ = self(x)  # raw logits
        loss = self.crit(y_, y)  # BCEWithLogitsLoss applies the sigmoid internally
        probs = self.sigmoid(y_)
        acc = self.acc(probs, y)
        precision = self.prec(probs, y)
        recall = self.rec(probs, y)
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        self.log("train_prec", precision)
        self.log("train_recall", recall)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_ = self(x)
        loss = self.crit(y_, y)
        probs = self.sigmoid(y_)
        acc = self.acc(probs, y)
        precision = self.prec(probs, y)
        recall = self.rec(probs, y)
        self.log("val_loss", loss)
        self.log("val_acc", acc)
        self.log("val_prec", precision)
        self.log("val_recall", recall)
        return loss
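Note that with two optimizers and no branching on optimizer_idx, Lightning runs training_step once per optimizer for every batch. A single optimizer with per-group learning rates would avoid that; here is a sketch of an alternative configure_optimizers for the class above (not what is used in this notebook):
def configure_optimizers(self):
    # hypothetical alternative: one Adam with discriminative learning rates
    return torch.optim.Adam([
        {"params": self.base.parameters(), "lr": 1e-6},  # nudge the pretrained body gently
        {"params": self.top.parameters(), "lr": 1e-3},   # train the fresh head faster
    ])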
pl_model = ToxLightning(model)
logger = pl.loggers.TensorBoardLogger("/GCI/tensorboard/tox", log_graph=True)
early_stopping = pl.callbacks.EarlyStopping(monitor="val_loss")
trainer = pl.Trainer(
    logger,
    gpus=1,
    callbacks=[early_stopping],
    fast_dev_run=False)
trainer.fit(pl_model, datamodule=tox_ldm)
pl_model = pl_model.eval()
def infer_smiles(x):
    """predict the probability that a SMILES string is mutagenic"""
    x = tokenizer(x, return_tensors="pt")["input_ids"]
    with torch.no_grad():
        y_ = pl_model.sigmoid(pl_model(x))
    return y_[0].item()
infer_smiles("CCOP(=S)(CC)Sc1ccccc1")
kegg_db = DBs("kegg")
with kegg_db.con.connect() as conn:
    ckb_kegg_drug_match = pd.read_sql("ckb_kegg_drug_match", con=conn)
ckb_kegg_drug_match["ames_score"] = ckb_kegg_drug_match.smiles.apply(
lambda x:infer_smiles(x) if x else x)
px.histogram(ckb_kegg_drug_match, x = "ames_score")
with kegg_db.con.connect() as conn:
    ckb_kegg_drug_match.to_sql(
        "ckb_kegg_drug_match",
        if_exists="replace",
        index=False,
        con=conn
    )
kegg_db = DBs("kegg")
with kegg_db.con.connect() as conn:
    ckb_kegg_drug_match = pd.read_sql("ckb_kegg_drug_match", con=conn)
ckb_kegg_drug_match.sample(20)
px.histogram(ckb_kegg_drug_match.query("smiles==smiles"), x="ames_score")  # smiles==smiles filters out NaN rows